{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "87157405-b0ed-4435-ba00-6bc436919e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec398cad-158a-4497-a684-0b16b1a22555",
   "metadata": {},
   "source": [
    "### Select the compute device (GPU if available, otherwise CPU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd0ea5a-543b-44d4-8ce4-24405d4cdcdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is available; otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e7a6ac1-5bf7-4275-9d12-4f3ffb56ae67",
   "metadata": {},
   "source": [
    "### Load model and tokenizer from HuggingFace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "19ece019-0d5c-4df9-bb96-99dd09a2b1cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|██████████| 283/283 [00:00<00:00, 185kB/s]\n",
      "Downloading: 100%|██████████| 2.11M/2.11M [00:00<00:00, 42.2MB/s]\n",
      "Downloading: 100%|██████████| 1.08k/1.08k [00:00<00:00, 735kB/s]\n",
      "Downloading: 100%|██████████| 99.0/99.0 [00:00<00:00, 65.0kB/s]\n",
      "Downloading: 100%|██████████| 1.09k/1.09k [00:00<00:00, 774kB/s]\n",
      "Downloading: 100%|██████████| 22.3k/22.3k [00:00<00:00, 14.0MB/s]\n",
      "Downloading: 100%|██████████| 9.98G/9.98G [01:53<00:00, 88.0MB/s]\n",
      "Downloading: 100%|██████████| 1.26G/1.26G [00:17<00:00, 73.4MB/s]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model_name = \"shailja/CodeGen_6B_Verilog\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"shailja/fine-tuned-codegen-6B-Verilog\")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"shailja/fine-tuned-codegen-6B-Verilog\").to(device)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aa4dd65-d0cc-4439-9bcb-537017a97f23",
   "metadata": {},
   "source": [
    "### Sampling Verilog for the task specified by the prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "87e819c1-bec2-4c16-b581-4f6a44e151e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "//a half adder module \n",
      "\n",
      "module half_adder(\n",
      "    input a,\n",
      "    input b,\n",
      "    output sum,\n",
      "    output carry\n",
      "    );\n",
      "\n",
      "assign sum = a ^ b;\n",
      "assign carry = a & b;\n",
      "\n",
      "endmodule\n"
     ]
    }
   ],
   "source": [
    "prompt = \"//a half adder module \"\n",
    "# prompt=\"// Design a 2-to-1 multiplexer.\\n module mux( \\n    input [4:0] a, b,\\n    input sel,\\n    output [4:0] out );\\n // When sel=0, assign a to out. \\n// When sel=1, assign b to out.\"\n",
    "n_steps = 1\n",
    "n = 5\n",
    "end_pattern='endmodule'\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
    "\n",
    "sample = model.generate(input_ids, max_length=128, temperature=0.5, top_p=0.9)\n",
    "print(tokenizer.decode(sample[0], truncate_before_pattern=[r\"endmodule\"]) + end_pattern)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "412aff69-6378-4628-89ad-bd1acbbda190",
   "metadata": {},
   "source": [
    "### Sampling Verilog on a per-token basis for debugging purposes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "9420f589-64cd-47eb-8408-9bd9ca47c95b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Input': tensor([[ 1003,    64,  2063,   751,   263,  8265,   220,   198,   198, 21412,\n",
      "          2063,    62, 26676,     7,   198, 50284, 15414,   257,    11,   198,\n",
      "         50284, 15414,   275,    11,   198, 50284, 22915,  2160,    11,   198,\n",
      "         50284, 22915,  3283,   198, 50284,  1776,   198,   198,   562,   570,\n",
      "          2160,   796,   257, 10563,   275,    26,   198,   562,   570,  3283,\n",
      "           796,   257,  1222,   275,    26,   198,   198,   437, 21412,   198]],\n",
      "       device='cuda:0')}\n",
      "//a half adder module \n",
      "\n",
      "module half_adder(\n",
      "    input a,\n",
      "    input b,\n",
      "    output sum,\n",
      "    output carry\n",
      "    );\n",
      "\n",
      "assign sum = a ^ b;\n",
      "assign carry = a & b;\n",
      "\n",
      "endmodule\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# max_length = \n",
    "\n",
    "# stop_word = \n",
    "iteration = dict()\n",
    "with torch.no_grad():\n",
    "    while True:\n",
    "        \n",
    "        iteration[\"Input\"] = input_ids\n",
    "        \n",
    "        output = model(input_ids)\n",
    "\n",
    "        # select logits of the first batch and the last token and apply softmax\n",
    "        next_token_logits = output.logits[0, -1, :]\n",
    "        next_token_probs = torch.softmax(next_token_logits, dim=-1)\n",
    "        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)\n",
    "        # print(sorted_ids)\n",
    "        \n",
    "        \n",
    "        # store tokens with highest probability\n",
    "        for choice_idx in range(n):\n",
    "            token_id = sorted_ids[choice_idx]\n",
    "            \n",
    "            token_prob = next_token_probs[token_id].cpu().numpy()\n",
    "            # print(token_id, token_prob)\n",
    "            # print(tokenizer.decode(token_id.view(1,-1)))\n",
    "            \n",
    "            token_choice = (              \n",
    "                f\"{tokenizer.decode(token_id)} ({100*token_prob:.2f}%)\"\n",
    "            )\n",
    "            \n",
    "            iteration[f\"{choice_idx+1}\"] = token_choice\n",
    "\n",
    "        #Append predictions to list\n",
    "        input_ids = torch.cat([input_ids, sorted_ids[None,0, None]], dim=-1)\n",
    "        \n",
    "        # condition checks on stop_words detected\n",
    "        if 'endmodule' in tokenizer.batch_decode(input_ids)[0]: break\n",
    "        \n",
    "print(tokenizer.batch_decode(input_ids)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb325b8f-6d6c-4cfb-b76c-f84d2225ade3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "codegen",
   "language": "python",
   "name": "codegen"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
